Intro

Neural style transfer is the process of generating a new image that combines the content of one image with the style(s) of another. It is an ill-posed problem: there is no single correct output, and traditional supervised learning cannot be readily applied because it would require paired (content image, stylized image) training examples, which are impractical to obtain. I propose to address this problem by modifying existing methods, including:

  1. A Neural Algorithm of Artistic Style by Gatys et al.
  2. A learned representation for Artistic Style by Dumoulin et al.

A Neural Algorithm of Artistic Style

Gatys et al. published one of the seminal works using neural networks to solve this problem. The most important idea is that "representations of content and style in the Convolutional Neural Network are separable". They use the following definitions to remove ambiguity from the problem.

  1. Two images are similar in content if their high-level features, as extracted by a trained classifier, are close in Euclidean distance.
  2. Two images are similar in style if their low-level features as extracted by a trained classifier share the same statistics or, more concretely, if the difference between the features’ Gram matrices has a small Frobenius norm.

Feature correlations among the feature maps are given by the Gram matrix, i.e., for a layer $l$, $G_{i j}^l$ is the inner product between the vectorized feature maps $i$ and $j$: $$ G_{i j}^l=\sum_k F_{i k}^l F_{j k}^l $$
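As a quick numerical check of this definition, here is a minimal NumPy sketch (the PyTorch implementation used later in this notebook additionally normalizes by the tensor size):

```python
import numpy as np

def gram_matrix(F):
    # F: array of shape [C, M], where row i is the vectorized feature map i
    # G_ij = sum_k F_ik F_jk, i.e. the matrix of inner products between maps
    return F @ F.T

F = np.array([[1.0, 2.0],
              [3.0, 4.0]])   # two feature maps, each with two spatial positions
G = gram_matrix(F)           # G[0, 1] = 1*3 + 2*4 = 11
```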

The objective is to minimize a weighted sum of the style and content losses; the desired stylized image is found by gradient descent.
Given content image $\vec{p}$, style image $\vec{a}$, and target image $\vec{x}$, the loss function is

$$\mathcal{L}_{\text {total }}(\vec{p}, \vec{a}, \vec{x})=\alpha \mathcal{L}_{\text {content }}(\vec{p}, \vec{x})+\beta \mathcal{L}_{\text {style }}(\vec{a}, \vec{x})$$

Limitations

  1. Prohibitively expensive, especially for high-resolution images: the task is modelled as an optimization problem requiring hundreds or thousands of iterations.
  2. The authors provide little evidence of applicability to a multi-style transfer extension.

Despite its flaws, the algorithm is very flexible, so I decided to explore this method first.

Modified Algorithm

I replace the single style loss with a weighted sum of loss terms, one per style.

$$\mathcal{L}_{\text {total }}(\vec{p}, \vec{a}, \vec{x})=\alpha \mathcal{L}_{\text {content }}(\vec{p}, \vec{x})+\sum_{i=1}^{N}\beta_i \mathcal{L}_{\text {style}_i}(\vec{a}_i, \vec{x})$$

Pretrained model: VGG19

VGG19

VGG19, pretrained on ImageNet, is used as a feature extractor. It achieved state-of-the-art results at the time of its publication. Since we are only concerned with features, not classification, the fully connected layers can be safely removed. The outputs of the individual convolution layers are used to compute the content and style losses.

The object information becomes increasingly explicit along the processing hierarchy. Detailed pixel information is lost while the high-level content of the image is preserved. Therefore, the conv_4 or conv_5 filters can be used to find the content loss.

The style representation is computed from the correlations between the different features in different layers of the model, where the expectation is taken over the spatial extent of the input image. Texture/style is the statistical relationship between the pixels of a source image, which is assumed to have a stationary distribution at some scale. Therefore, all the conv outputs can be used to compute the style loss.

Note

I found a relevant tutorial, Neural Transfer Using PyTorch. I am investing my efforts only in the interesting and important parts, so I have used most of the boilerplate code from this tutorial, with a few changes to make it applicable to multi-style transfer.

# find the loss for each style

if name in style_layers:
    style_loss = []
    for index, style_image in enumerate(style_imgs):
        # `i` is the running conv-layer index from the enclosing loop
        target_feature = model(style_image).detach()
        style_loss_index = StyleLoss(target_feature)
        model.add_module("style_loss_{}_{}".format(i, index), style_loss_index)
        style_loss.append(style_loss_index)
    style_losses.append(style_loss)

# combine the per-style losses using the style weights

for sl in style_losses:
    for index, style_loss in enumerate(sl):
        style_score += style_coeffs[index] * style_loss.loss

In [13]:
%matplotlib inline

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from PIL import Image
import matplotlib.pyplot as plt

import torchvision.transforms as transforms
from torchvision.models import vgg19, VGG19_Weights

import copy

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.set_default_device(device)


from icecream import ic
# desired size of the output image
imsize = 512 if torch.cuda.is_available() else 128  # use small size if no GPU

loader = transforms.Compose([
    transforms.Resize((imsize, imsize)),  # scale imported image
    transforms.ToTensor()])  # transform it into a torch tensor


def image_loader(image_name):
    image = Image.open(image_name)
    # fake batch dimension required to fit network's input dimensions
    image = loader(image).unsqueeze(0)
    return image.to(device, torch.float)


style_img_1 = image_loader("images/vanGogh.jpg")
style_img_2 = image_loader("images/daVinci.jpg")

content_path = "images/Masterlayer_Event221_SetA.png"
original_dim = Image.open(content_path).size
content_img = image_loader(content_path)

assert style_img_1.size() == content_img.size() == style_img_2.size(), \
    "we need to import style and content images of the same size"

unloader = transforms.ToPILImage()  # reconvert into PIL image

def get_PIL_image(tensor):
    image = tensor.cpu().clone()  # we clone the tensor to not do changes on it
    image = image.squeeze(0)      # remove the fake batch dimension
    image = unloader(image)
    return image

fig, axs = plt.subplots(1, 3, figsize=(20, 10))
images = [style_img_1, style_img_2, content_img]
titles = ["Style Image 1", "Style Image 2", "Content Image"]
for i in range(3):
    img = get_PIL_image(images[i])
    axs[i].imshow(img, interpolation='nearest')
    axs[i].set_title(titles[i])
    axs[i].axis('off')

Now the style loss module looks almost exactly like the content loss module. The style distance is also computed using the mean square error between $G_{XL}$ and $G_{SL}$.
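The `gram_matrix` helper and the `ContentLoss` module referenced by the cells below come from the tutorial, but their cells are not shown above; for completeness, a faithful sketch following the tutorial's definitions:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def gram_matrix(input):
    b, c, h, w = input.size()             # batch, channels, height, width
    features = input.view(b * c, h * w)   # vectorize each feature map
    G = torch.mm(features, features.t())  # inner products between the maps
    return G.div(b * c * h * w)           # normalize by the number of elements

class ContentLoss(nn.Module):
    def __init__(self, target):
        super(ContentLoss, self).__init__()
        # detach the target: it is a fixed reference, not a variable to optimize
        self.target = target.detach()

    def forward(self, input):
        self.loss = F.mse_loss(input, self.target)
        return input  # transparent layer: pass the input through unchanged
```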

In [16]:
class StyleLoss(nn.Module):

    def __init__(self, target_feature):
        super(StyleLoss, self).__init__()
        self.target = gram_matrix(target_feature).detach()

    def forward(self, input):
        G = gram_matrix(input)
        self.loss = F.mse_loss(G, self.target)
        return input
In [17]:
cnn = vgg19(weights=VGG19_Weights.DEFAULT).features.eval()
Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /home2/aman.atman/.cache/torch/hub/checkpoints/vgg19-dcbb9e9d.pth
Additionally, VGG networks are trained on images with each channel normalized by mean=[0.485, 0.456, 0.406] and std=[0.229, 0.224, 0.225]. We will use them to normalize the image before sending it into the network.

In [18]:
cnn_normalization_mean = torch.tensor([0.485, 0.456, 0.406])
cnn_normalization_std = torch.tensor([0.229, 0.224, 0.225])

# create a module to normalize input image so we can easily put it in a
# ``nn.Sequential``
class Normalization(nn.Module):
    def __init__(self, mean, std):
        super(Normalization, self).__init__()
        # .view the mean and std to make them [C x 1 x 1] so that they can
        # directly work with image Tensor of shape [B x C x H x W].
        # B is batch size. C is number of channels. H is height and W is width.
        self.mean = mean.clone().detach().view(-1, 1, 1)
        self.std = std.clone().detach().view(-1, 1, 1)

    def forward(self, img):
        # normalize ``img``
        return (img - self.mean) / self.std
In [19]:
# desired depth layers to compute style/content losses :
content_layers_default = ['conv_4']
style_layers_default = ['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5']

def get_style_model_and_losses(cnn, normalization_mean, normalization_std,
                               style_imgs, content_img,
                               content_layers=content_layers_default,
                               style_layers=style_layers_default):
    # normalization module
    normalization = Normalization(normalization_mean, normalization_std)

    # just in order to have an iterable access to or list of content/style
    # losses
    content_losses = []
    style_losses = []

    # assuming that ``cnn`` is a ``nn.Sequential``, so we make a new ``nn.Sequential``
    # to put in modules that are supposed to be activated sequentially
    model = nn.Sequential(normalization)

    i = 0  # increment every time we see a conv
    for layer in cnn.children():
        if isinstance(layer, nn.Conv2d):
            i += 1
            name = 'conv_{}'.format(i)
        elif isinstance(layer, nn.ReLU):
            name = 'relu_{}'.format(i)
            # The in-place version doesn't play very nicely with the ``ContentLoss``
            # and ``StyleLoss`` we insert below. So we replace with out-of-place
            # ones here.
            layer = nn.ReLU(inplace=False)
        elif isinstance(layer, nn.MaxPool2d):
            name = 'pool_{}'.format(i)
        elif isinstance(layer, nn.BatchNorm2d):
            name = 'bn_{}'.format(i)
        else:
            raise RuntimeError('Unrecognized layer: {}'.format(layer.__class__.__name__))

        model.add_module(name, layer)

        if name in content_layers:
            # add content loss:
            target = model(content_img).detach()
            content_loss = ContentLoss(target)
            model.add_module("content_loss_{}".format(i), content_loss)
            content_losses.append(content_loss)

        if name in style_layers:
            style_loss = []
            for index, style_image in enumerate(style_imgs):
                target_feature = model(style_image).detach()
                style_loss_index = StyleLoss(target_feature)
                model.add_module("style_loss_{}_{}".format(i, index), style_loss_index)
                style_loss.append(style_loss_index)
            style_losses.append(style_loss)

    # now we trim off the layers after the last content and style losses
    for i in range(len(model) - 1, -1, -1):
        if isinstance(model[i], ContentLoss) or isinstance(model[i], StyleLoss):
            break

    model = model[:(i + 1)]

    return model, style_losses, content_losses

Next, we select the input image. You can use a copy of the content image or white noise.

In [20]:
def get_input_optimizer(input_img):
    # this line to show that input is a parameter that requires a gradient
    optimizer = optim.LBFGS([input_img])
    return optimizer

Finally, we must define a function that performs the neural transfer. For each iteration of the networks, it is fed an updated input and computes new losses. We will run the backward methods of each loss module to dynamically compute their gradients. The optimizer requires a β€œclosure” function, which reevaluates the module and returns the loss.

We still have one final constraint to address. The network may try to optimize the input with values that exceed the 0 to 1 tensor range for the image. We can address this by correcting the input values to be between 0 to 1 each time the network is run.

In [21]:
def run_style_transfer(cnn, normalization_mean, normalization_std,
                       content_img, style_imgs, style_coeffs, input_img, num_steps=300,
                       style_weight=1000000, content_weight=1):
    """Run the style transfer."""
    print('Building the style transfer model..')
    model, style_losses, content_losses = get_style_model_and_losses(cnn,
        normalization_mean, normalization_std, style_imgs, content_img)

    # We want to optimize the input and not the model parameters so we
    # update all the requires_grad fields accordingly
    input_img.requires_grad_(True)
    # We also put the model in evaluation mode, so that specific layers 
    # such as dropout or batch normalization layers behave correctly. 
    model.eval()
    model.requires_grad_(False)

    optimizer = get_input_optimizer(input_img)

    print('Optimizing..')
    run = [0]
    while run[0] <= num_steps:

        def closure():
            # correct the values of updated input image
            with torch.no_grad():
                input_img.clamp_(0, 1)

            optimizer.zero_grad()
            model(input_img)
            style_score = 0
            content_score = 0

            for sl in style_losses:
                for index, style_loss in enumerate(sl):
                    style_score += style_coeffs[index] * style_loss.loss
            for cl in content_losses:
                content_score += cl.loss

            style_score *= style_weight
            content_score *= content_weight

            loss = style_score + content_score
            loss.backward()

            run[0] += 1
            # if run[0] % 50 == 0:
            #     print("run {}:".format(run))
            #     print('Style Loss : {:4f} Content Loss: {:4f}'.format(
            #         style_score.item(), content_score.item()))
            #     print()

            return loss

        optimizer.step(closure)

    # a last correction...
    with torch.no_grad():
        input_img.clamp_(0, 1)

    return input_img

Finally, we can run the algorithm.

In [22]:
style_imgs = [style_img_1, style_img_2]
style_coeffs = [0.5, 0.5]
input_img = content_img.clone()
output = run_style_transfer(cnn, cnn_normalization_mean, cnn_normalization_std,
                            content_img, style_imgs, style_coeffs, input_img)

plt.figure(figsize=(20, 10))

output = output.cpu().clone()
output = output.squeeze(0)
output = unloader(output)
output = transforms.Resize((original_dim[1], original_dim[0]))(output)

plt.title("Stylized Image")
plt.imshow(output)
# remove ticks and axis
plt.tick_params(bottom=False, left=False, labelbottom=False, labelleft=False)
plt.savefig("output/stylized_Masterlayer_Event221_SetA.png", dpi=300)
Building the style transfer model..
Optimizing..

This approach is effective but not scalable, and it is impractical at the desired resolution. Johnson et al., in "Perceptual Losses for Real-Time Style Transfer and Super-Resolution", addressed this limitation by introducing a feedforward style transfer network trained to transform a content image into a stylized image in a single pass. However, it is not directly applicable to our problem, as each network is trained on a single style.

A Learned Representation for Artistic Style

Dumoulin et al. describe a method that scales easily to $N$ styles while remaining fast. A visual texture or style is conjectured to be spatially homogeneous, consisting of repeated structural motifs whose minimal sufficient statistics are captured by lower-order statistical measurements. The most important idea is that styles likely share some degree of computation; for example, different paintings may have similar brush strokes but differ in their color palettes. They propose Conditional Instance Normalization (CIN), which allows all convolutional weights of a style transfer network to be shared across many styles: it is sufficient to tune the parameters of an affine transformation, applied after normalization, for each style. After normalizing a layer's activations $x$, scaling and shifting are all that is required to condition on a specific style.

$$z=\gamma_s\left(\frac{x-\mu}{\sigma}\right)+\beta_s$$

where $\mu$ and $\sigma$ are $x$'s mean and standard deviation across both the spatial axes.
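A minimal NumPy sketch of this transformation (the shapes are illustrative; the actual implementation lives inside the style network's layers):

```python
import numpy as np

def conditional_instance_norm(x, gamma_s, beta_s, eps=1e-5):
    # x: activations of shape [C, H, W]; gamma_s, beta_s: per-channel
    # scale and shift for the selected style s, each of shape [C]
    mu = x.mean(axis=(1, 2), keepdims=True)     # per-channel spatial mean
    sigma = x.std(axis=(1, 2), keepdims=True)   # per-channel spatial std
    z = (x - mu) / (sigma + eps)
    return gamma_s[:, None, None] * z + beta_s[:, None, None]
```

Mixing styles then amounts to combining the per-style $(\gamma_s, \beta_s)$ parameters.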

This method is demonstrated to generalize across a diversity of artistic styles, reducing a painting to a point in an embedding space and permitting a user to explore new painting styles by arbitrarily combining the styles learned from individual paintings.

Architecture

johnson arch

The Image Transform Network is a feedforward style transfer network trained to go from content to stylized image in one pass. A pretrained VGG serves as the loss network, similar to the work of Gatys et al.

Image Transform Network

style

The Image Transform Network is a deep residual convolutional neural network that converts a content image into a stylized image. Instead of pooling layers, which cause information loss, it uses strided convolutions. Residual blocks allow for a deeper network. Nearest-neighbour upsampling is used instead of fractional (transposed) convolution to prevent checkerboard patterns.
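A hedged PyTorch sketch of the two distinctive building blocks described above (the channel counts are illustrative assumptions, not taken from the paper):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ResidualBlock(nn.Module):
    # two 3x3 convs with a skip connection; enables a deeper network
    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)

    def forward(self, x):
        return x + self.conv2(F.relu(self.conv1(x)))

class UpsampleConv(nn.Module):
    # nearest-neighbour upsampling followed by a convolution, avoiding the
    # checkerboard artifacts of fractional (transposed) convolution
    def __init__(self, in_ch, out_ch, scale=2):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.scale = scale

    def forward(self, x):
        x = F.interpolate(x, scale_factor=self.scale, mode='nearest')
        return self.conv(x)
```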

I am using the entire pretrained model, not just the VGG loss network, because of computational and time constraints. It has been trained on the 'varied' set of paintings.

Sources - https://github.com/magenta/magenta/tree/main/magenta/models/image_stylization

In [23]:
import matplotlib.pyplot as plt
images = ["edwin.jpg", "lautrec.jpg", "rouault.jpeg", "signac.jpg"]

fig, axs = plt.subplots(2, 2, figsize=(10, 10))
plt.tight_layout()
# set title

index = 0
for i in range(2):
    for j in range(2):
        img = image_loader("images/" + images[index])
        img = get_PIL_image(img)
        axs[i, j].imshow(img, interpolation='nearest')
        name = images[index].split(".")[0]
        name = name.capitalize()
        axs[i, j].set_title(name)
        axs[i, j].axis('off')
        index += 1

# add caption

Here are a few of the paintings the model was trained on.

Next, define the utility functions for getting the stylized images from the pretrained model.

In [24]:
import ast
import os
from skimage import io
import numpy as np
import tensorflow.compat.v1 as tf
from icecream import ic
import matplotlib.pyplot as plt

from image_stylization import image_utils, model, ops



def create_style(n=10, distrib=None):
    if distrib is None:
        distrib = np.random.rand(n)
    distrib = distrib / np.sum(distrib)
    styles = {}
    for i in range(n):
        styles[i] = distrib[i]
    return styles



def _load_checkpoint(sess, checkpoint):
    """Loads a checkpoint file into the session."""
    model_saver = tf.train.Saver(tf.global_variables())
    checkpoint = os.path.expanduser(checkpoint)
    if tf.gfile.IsDirectory(checkpoint):
        checkpoint = tf.train.latest_checkpoint(checkpoint)
        tf.logging.info("loading latest checkpoint file: {}".format(checkpoint))
    model_saver.restore(sess, checkpoint)


def _describe_style(which_styles):
    """Returns a string describing a linear combination of styles."""

    def _format(v):
        formatted = str(int(round(v * 1000.0)))
        while len(formatted) < 3:
            formatted = "0" + formatted
        return formatted

    values = []
    for k in sorted(which_styles.keys()):
        values.append("%s_%s" % (k, _format(which_styles[k])))
    return "_".join(values)


def _style_mixture(which_styles, num_styles):
    """Returns a 1-D array mapping style indexes to weights."""
    if not isinstance(which_styles, dict):
        raise ValueError("Style mixture must be a dictionary.")
    mixture = np.zeros([num_styles], dtype=np.float32)
    for index in which_styles:
        mixture[index] = which_styles[index]
    return mixture



def _multiple_styles(input_image, which_styles, output_dir, checkpoint, output_basename, num_styles=10, alpha=1.0, asset_path=None):
    """Stylizes image into a linear combination of styles and writes to disk."""
    with tf.Graph().as_default(), tf.Session() as sess:
        mixture = _style_mixture(which_styles, num_styles)
        stylized_images = model.transform(
            input_image,
            alpha=alpha,
            normalizer_fn=ops.weighted_instance_norm,
            normalizer_params={
                "weights": tf.constant(mixture),
                "num_categories": num_styles,
                "center": True,
                "scale": True,
            },
        )
        _load_checkpoint(sess, checkpoint)

        stylized_image = stylized_images.eval()

        image = np.uint8(stylized_image * 255.0)
        
        if asset_path is None:
            image_utils.save_np_image(
                stylized_image,
                os.path.join(
                    output_dir,
                    "%s_%s.png" % (output_basename, _describe_style(which_styles)),
                ),
            )
        else:
            image_utils.save_np_image(
                stylized_image,
                asset_path
            )
        return np.squeeze(image, 0)
    
from ipywidgets import interact

def browse_images(varied_stylized):
    n = len(varied_stylized)
    def view_image(image):
        plt.figure(figsize=(20, 10))
        plt.imshow(varied_stylized[image],interpolation='nearest')
        plt.tick_params(bottom=False, left=False, labelbottom=False, labelleft=False)
        plt.show()
        plt.pause(0.001)
    interact(view_image, image=(0,n-1))
In [25]:
tf.disable_v2_behavior()
input_image = "images/Masterlayer_Event221_SetA.png"
checkpoint = "image_stylization/multistyle-pastiche-generator-varied.ckpt"
output_basename = "all_monet_styles"
output_dir = "output"
content_image = np.expand_dims(
    image_utils.load_np_image(os.path.expanduser(input_image)), 0
)
 

output_dir = os.path.expanduser(output_dir)
if not os.path.exists(output_dir):
    os.makedirs(output_dir)
num_styles = 32
which_styles =  create_style(32)
_multiple_styles(content_image, which_styles, output_dir, checkpoint, output_basename, num_styles);

varied_stylized = []
for i in range(32):
    distrib = np.zeros(32)
    distrib[i] = 1
    which_styles =  create_style(32, distrib)
    varied_stylized.append(_multiple_styles(content_image, which_styles, output_dir, checkpoint, output_basename, num_styles))

browse_images(varied_stylized)
INFO:tensorflow:Restoring parameters from image_stylization/multistyle-pastiche-generator-varied.ckpt

We can browse through the styles using the slider.

Now, suppose we want to narrow down to a few styles of interest. I am choosing the following without loss of generality.

In [26]:
style_indices = [1, 31, 23, 8]

fig, axs = plt.subplots(2, 2, figsize=(15, 15))
plt.tight_layout()


for i in range(4):
    axs[i//2, i%2].imshow(varied_stylized[style_indices[i]],interpolation='nearest')
    axs[i//2, i%2].axis('off')

How to find the right mix?

Deciding what weight to give each individual style can become intractable: only a small region of the space of weight vectors may generate useful and pleasing stylized images. We can recognize the desired behavior but cannot necessarily demonstrate it, which motivates the application of reinforcement learning from human feedback.

Reinforcement Learning from Human Feedback (RLHF)

RLHF was used by OpenAI to significantly improve their GPT models. Inspired by this, I want to test whether the idea is relevant for our problem as well. We can leverage human feedback to learn a reward model that assigns a score to any particular weight-distribution vector; this can then be used to find the optimal set of style weights.

RLHF is very flexible and there are multiple ways to implement it. I choose a very simple algorithm and treat the style transfer network as a black box. I do not want to delve into model training because of computational constraints.

Let us first see how we can use human feedback to learn the reward model.

Steps

  • Create a UI that presents a pair of images to the human labeler, who chooses a winner. Use the choices to construct a dataset of pairs $(y_w, y_l)$, where $y_w$ is the style-weight vector of the winning image and $y_l$ that of the losing image.
  • Learn the reward model $r_\phi$ from these pairs.

rlhf algo

Source : https://huyenchip.com/2023/05/02/rlhf.html#phase_3_rlhf
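
Concretely, the second step amounts to minimizing the standard pairwise (Bradley–Terry) preference loss over the collected pairs, which is exactly what the `RLHFLoss` implemented below computes for a single pair:

$$\mathcal{L}(\phi)=-\mathbb{E}_{(y_w, y_l)}\left[\log \sigma\left(r_\phi(y_w)-r_\phi(y_l)\right)\right]$$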

After learning the $r_\phi$, we can use the standard RL algorithms to find the optimal weights. I will give a brief overview of this.

An RL agent interacts with the world (called its environment) by using a policy to choose, at every time step, from a set of actions. The environment responds by transitioning the agent to its next state and providing a reward attributed to the last action taken from the prior state.

rl

A policy is a behaviour function that maps states to actions. We want to learn the optimal policy, i.e. the distribution over actions that gives the highest reward.

$$ a=\pi(s) $$

I am modelling the problem by making the following assumptions -

  1. Action space $\mathcal{A}$ is continuous and multivariate, i.e. $a \in \mathbb{R}^N$ where $N$ is the number of styles; each action is a distribution over the styles.

  2. There is a single state: the initial content image. We want to learn the optimal policy, which gives the most useful and pleasing stylized image.

We can directly use any of the standard RL algorithms.
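
For illustration only, here is how the single-state, continuous-action setup maps onto vanilla REINFORCE with a Gaussian policy over the weight vector. The quadratic `reward_fn` is a hypothetical stand-in for the learned reward model $r_\phi$, and `target` is an arbitrary made-up optimum; this is a sketch, not the pipeline used in this notebook.

```python
import torch

torch.manual_seed(0)

N = 4  # number of styles

# Hypothetical stand-in for the learned reward model r_phi:
# a quadratic bowl peaked at an arbitrary `target` weight vector.
target = torch.tensor([0.7, 0.1, 0.1, 0.1])
def reward_fn(a):
    return -((a - target) ** 2).sum()

# Gaussian policy over actions a in R^N: a ~ Normal(mu, sigma)
mu = torch.zeros(N, requires_grad=True)
log_sigma = torch.zeros(N, requires_grad=True)
optimizer = torch.optim.Adam([mu, log_sigma], lr=0.05)

for step in range(500):
    dist = torch.distributions.Normal(mu, log_sigma.exp())
    actions = dist.sample((64,))               # batch of sampled weight vectors
    rewards = torch.stack([reward_fn(a) for a in actions])
    baseline = rewards.mean()                  # simple variance-reduction baseline
    log_probs = dist.log_prob(actions).sum(dim=1)
    # REINFORCE: maximize E[(reward - baseline) * log pi(a)]
    loss = -((rewards - baseline) * log_probs).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```

After training, `mu` drifts toward the high-reward region of the action space; with a real $r_\phi$ in place of `reward_fn`, `mu` would be the recovered style-weight vector.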

Note¶

I will not be implementing the RL algorithm to find the optimal weights because of practical issues. I am not sure RL would be useful given there is only a single state. Furthermore, RL algorithms are notoriously data-hungry and unstable, and we simply can't generate huge preference datasets quickly. Direct Preference Optimization could come in handy here, though.
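
For reference, DPO skips the explicit reward model and optimizes the policy directly on the preference pairs (written here without a conditioning prompt, since we have a single state; $\pi_{\text{ref}}$ is a frozen reference policy and $\beta$ a temperature):

$$\mathcal{L}_{\mathrm{DPO}}(\theta)=-\mathbb{E}_{(y_w, y_l)}\left[\log \sigma\left(\beta \log \frac{\pi_\theta(y_w)}{\pi_{\text{ref}}(y_w)}-\beta \log \frac{\pi_\theta(y_l)}{\pi_{\text{ref}}(y_l)}\right)\right]$$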

I will only be implementing the reward model and training it on a toy dataset.

In [27]:
# construct a toy dataset
import numpy as np
import json
dataset = []
num_chosen_styles = 4
np.random.seed(0)
N = 10

for i in range(N):
    # sample two random non-negative vectors of length num_chosen_styles
    a = np.absolute(np.random.randn(num_chosen_styles))
    b = np.absolute(np.random.randn(num_chosen_styles))
    dataset += [(a, b)]  # interpreted later as (losing_weights, winning_weights)
RLHF UI¶

The images load slowly, so please wait for the Choose message before clicking the buttons.

In [28]:
# 
import ipywidgets as widgets
from IPython.display import display
from PIL import Image
import matplotlib.pyplot as plt
import time

index = 0
def create_style_from_choices(style_weights):
    distrib = np.zeros(32)
    for i in range(num_chosen_styles):
        distrib[style_indices[i]] = style_weights[i]
    return create_style(32, distrib)

def get_image(style_weights):
    style = create_style_from_choices(style_weights)
    image = _multiple_styles(content_image, style, output_dir, checkpoint, output_basename, num_styles)
    return Image.fromarray(image)._repr_png_()

responses = []

# Load the images
size = 420
image_widget1 = widgets.Image(value=get_image(dataset[index][0]), format='png', width=size, height=size)
image_widget2 = widgets.Image(value=get_image(dataset[index][1]), format='png', width=size, height=size)
text_widget = widgets.Text(value='Choose')
images_hbox = widgets.HBox([image_widget1, image_widget2])

# button to mark choice among the two images, switch to next set of images on each press
button1 = widgets.Button(description='Left')
button2 = widgets.Button(description='Right')
# button callback: log the choice, then advance to the next pair
def update_images(button):
    global index
    responses.append(button.description)  # choice for the currently shown pair
    index += 1
    if index >= len(dataset):
        text_widget.value = 'over'
        return
    text_widget.value = 'waiting...'
    image_widget1.value = get_image(dataset[index][0])
    image_widget2.value = get_image(dataset[index][1])
    time.sleep(5)  # give the new images time to render
    text_widget.value = 'Choose'


button1.on_click(update_images)
button2.on_click(update_images)

# Display the images
display(images_hbox)
display(button1, button2, text_widget)
INFO:tensorflow:Restoring parameters from image_stylization/multistyle-pastiche-generator-varied.ckpt

I use a simple neural network to implement the reward model.

In [29]:
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F

class RLHFLoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.loss = nn.LogSigmoid()
    def forward(self, losing_score, winning_score):
        return -self.loss(winning_score - losing_score)
    
class RewardModel(nn.Module):
    def __init__(self, input_dim):
        super().__init__()
        self.hidden_dim = input_dim
        self.fc1 = nn.Linear(input_dim, self.hidden_dim)
        self.fc2 = nn.Linear(self.hidden_dim, 1)
        self.apply(self.weights_init)

    def weights_init(self, m):
        if isinstance(m, nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight.data)
            if m.bias is not None:
                m.bias.data.fill_(0.0)
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x
    

num_epochs = 10
# move the model to the GPU to match the .cuda() calls on its inputs below
reward_model = RewardModel(num_chosen_styles).cuda()

optimizer = torch.optim.Adam(reward_model.parameters())
loss_fn = RLHFLoss()

loss_values = []

for epoch in range(num_epochs):
    total_loss = 0
    for losing_weights, winning_weights in dataset:
        optimizer.zero_grad()
        losing_score = reward_model(torch.from_numpy(losing_weights).float().cuda())
        winning_score = reward_model(torch.from_numpy(winning_weights).float().cuda())

        loss = loss_fn(losing_score, winning_score)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    
    loss_values.append(total_loss)

plt.plot(loss_values)
plt.title('Loss over time')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.show()
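
Since the reward model is itself a differentiable network, one lightweight alternative to a full RL loop is gradient ascent on $r_\phi$ over a softmax-parameterized weight vector. The sketch below runs on CPU with a freshly initialized stand-in model of the same two-layer shape (an assumption for self-containedness; in the notebook you would reuse the trained `reward_model`):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

num_chosen_styles = 4

# Stand-in for the trained reward model (same two-layer architecture).
surrogate = nn.Sequential(
    nn.Linear(num_chosen_styles, num_chosen_styles),
    nn.ReLU(),
    nn.Linear(num_chosen_styles, 1),
)
for p in surrogate.parameters():
    p.requires_grad_(False)  # freeze the model; we optimize only the weights

# Parameterize the style weights through a softmax so they stay
# non-negative and sum to one.
logits = torch.zeros(num_chosen_styles, requires_grad=True)
optimizer = torch.optim.Adam([logits], lr=0.05)

initial_reward = surrogate(F.softmax(logits, dim=0)).item()
for _ in range(200):
    optimizer.zero_grad()
    reward = surrogate(F.softmax(logits, dim=0))
    (-reward).backward()  # ascend the reward
    optimizer.step()

best_weights = F.softmax(logits, dim=0).detach()
final_reward = surrogate(best_weights).item()
print(best_weights, final_reward)
```

This sidesteps sampling entirely, at the cost of trusting the reward model's gradients outside the region covered by the preference data.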

Now we plot each equal-weight pair ($\binom{4}{2} = 6$) of the $4$ styles, annotated with the reward model's score.

In [30]:
weights_store = []
for i in range(4):
    for j in range(i + 1, 4):
        a = np.zeros(num_chosen_styles)
        a[i] = a[j] = 1
        weights_store.append(a)

fig, axs = plt.subplots(2, 3, figsize=(20, 10))

for i in range(2):
    for j in range(3):
        style = create_style_from_choices(weights_store[i*3 + j])
        image = _multiple_styles(content_image, style, output_dir, checkpoint, output_basename, num_styles)
        axs[i, j].imshow(image, interpolation='nearest')
        axs[i, j].axis('off')
        axs[i, j].set_title(f"Style {i*3 + j} score {reward_model(torch.from_numpy(weights_store[i*3 + j]).float().cuda()).item():.3f}")
INFO:tensorflow:Restoring parameters from image_stylization/multistyle-pastiche-generator-varied.ckpt

Now we can run the algorithm on all assets.

In [31]:
import os
import numpy as np

assets_dir = "assets/Room_SetA/"
weights_store = [[1, 1, 0, 0]]
num_chosen_styles = 4

for image_name in os.listdir(assets_dir):
    if os.path.isdir(os.path.join(assets_dir, image_name)):
        continue
    print(image_name)
    content_image = np.expand_dims(
        image_utils.load_np_image(os.path.join(assets_dir, image_name)), 0
    )
    print(content_image.shape)
    style = create_style_from_choices(weights_store[0])
    image = _multiple_styles(content_image, style, output_dir, checkpoint, output_basename, num_styles, asset_path="assets/outputs/"+image_name)
(1, 327, 296, 3)
INFO:tensorflow:Restoring parameters from image_stylization/multistyle-pastiche-generator-varied.ckpt
(1, 481, 844, 3)
(1, 222, 1208, 3)
(1, 97, 380, 3)
(1, 163, 1236, 3)
(1, 95, 490, 3)
(1, 1175, 620, 3)
(1, 100, 153, 3)
(1, 170, 1232, 3)
(1, 265, 777, 3)
(1, 926, 2048, 3)
(1, 340, 697, 3)
(1, 293, 263, 3)
(1, 424, 676, 3)
(1, 255, 1005, 3)
(1, 186, 398, 3)
(1, 509, 710, 3)
(1, 207, 293, 3)
(1, 172, 289, 3)
(1, 94, 234, 3)
(1, 1536, 2048, 3)
(1, 176, 275, 3)
(1, 224, 634, 3)
(1, 976, 2048, 3)
(1, 189, 275, 3)
(1, 234, 932, 3)
(1, 29, 1089, 3)
(1, 278, 356, 3)
(1, 324, 300, 3)
(1, 65, 695, 3)
(1, 613, 2048, 3)
(1, 126, 297, 3)
(1, 193, 217, 3)
(1, 219, 352, 3)
Limitations¶
  1. We learn the stylized image indirectly through the individual layer embeddings; learning directly from the style and content images would be faster.
  2. This method does not leverage current generative models, which would likely produce much higher-quality images.